import asyncio
import threading
from typing import List, Dict, Any, Union

from PIL import Image
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

class Message:
    """One chat turn: a role string plus its content.

    Content is either a plain string or an already-built list of content-part
    dicts; ``to_dict`` normalizes both into the chat-template message shape.
    """

    def __init__(self, role: str, content: Union[str, List[Dict[str, Any]]]):
        # Speaker role, passed through verbatim to the chat template.
        self.role = role
        # Plain text, or a list of content-part dicts used as-is.
        self.content = content

    def to_dict(self) -> Dict[str, Any]:
        """Return ``{"role": ..., "content": [...]}`` for the chat template.

        A string payload becomes a single ``{"type": "text", "text": ...}``
        part; a list payload is returned unchanged (the same list object).
        """
        body = self.content
        if isinstance(body, str):
            body = [{"type": "text", "text": body}]
        return {"role": self.role, "content": body}

class QwenChatModel:
    """Async chat wrapper around a local Qwen2.5-VL model loaded with transformers.

    The blocking encode/generate/decode work runs in the event loop's default
    executor so awaiting coroutines are not stalled, and a lock ensures only one
    ``generate`` call runs on the model at a time.
    """

    def __init__(self,
                 model_name: str = "Qwen/Qwen2.5-VL-7B-Instruct",
                 temperature: float = 0.0,
                 max_new_tokens: int = 512):
        """Load the model and its processor.

        Args:
            model_name: Model id or path passed to ``from_pretrained``.
            temperature: Default sampling temperature; a value ``<= 0`` selects
                greedy decoding.
            max_new_tokens: Default cap on newly generated tokens per call.
        """
        self.temperature = temperature
        self.max_new_tokens = max_new_tokens
        # Generation runs in executor threads (see ainvoke); this lock keeps at
        # most one generate() on the shared model at a time, so concurrent
        # requests queue instead of competing for the same device.
        self._generate_lock = threading.Lock()

        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_name,
            # NOTE(review): Qwen2.5-VL is commonly run in bfloat16; float16 may be
            # less numerically stable for this model — confirm on target hardware.
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto"
        )
        self.processor = AutoProcessor.from_pretrained(model_name)

    def _build_inputs(self, messages: List["Message"]):
        """Render *messages* with the chat template and encode text + vision inputs.

        Returns the processor output for a batch of one conversation, moved to
        ``self.model.device``.
        """
        chat_format = [m.to_dict() for m in messages]
        text = self.processor.apply_chat_template(chat_format, tokenize=False, add_generation_prompt=True)
        # qwen_vl_utils helper; presumably collects the image/video payloads
        # referenced by the content parts.
        image_inputs, video_inputs = process_vision_info(chat_format)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            return_tensors="pt",
            padding=True
        )
        return inputs.to(self.model.device)

    def _generation_kwargs(self, **overrides) -> Dict[str, Any]:
        """Build the decoding arguments passed to ``model.generate``.

        ``max_new_tokens`` and ``temperature`` may be overridden per call (a
        value of ``None`` falls back to the instance default); other keys in
        *overrides* are ignored. A positive temperature enables sampling with
        that temperature. Otherwise decoding is greedy and no temperature is
        sent, because temperature only applies in sampling mode.
        """
        max_new_tokens = overrides.get("max_new_tokens")
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens
        temperature = overrides.get("temperature")
        if temperature is None:
            temperature = self.temperature

        gen_kwargs: Dict[str, Any] = {"max_new_tokens": max_new_tokens}
        if temperature is not None and temperature > 0:
            gen_kwargs["do_sample"] = True
            gen_kwargs["temperature"] = temperature
        else:
            gen_kwargs["do_sample"] = False
        return gen_kwargs

    def _generate_sync(self, messages: List["Message"], gen_kwargs: Dict[str, Any]) -> str:
        """Blocking encode -> generate -> decode for one conversation."""
        with self._generate_lock:
            inputs = self._build_inputs(messages)
            outputs = self.model.generate(**inputs, **gen_kwargs)
            # The output rows start with the prompt tokens; keep only the continuation.
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, outputs)
            ]
        output_texts = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        return output_texts[0]

    async def ainvoke(self, messages: List["Message"], **kwargs) -> str:
        """Generate a reply to *messages* and return the decoded text.

        Recognized keyword overrides are ``max_new_tokens`` and ``temperature``;
        any other keyword arguments are accepted and ignored. The blocking model
        call runs in the running loop's default executor so the event loop stays
        responsive while the model generates.
        """
        gen_kwargs = self._generation_kwargs(**kwargs)
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._generate_sync, messages, gen_kwargs)

    async def ainvoke_full(self, messages: List["Message"], **kwargs):
        """Same as :meth:`ainvoke`; keyword overrides are forwarded."""
        return await self.ainvoke(messages, **kwargs)

    async def aclose(self):
        """No-op: this wrapper releases nothing explicitly."""
        pass
